import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
import numpy as np
import os

def embedding_forward(
    self,
    input_ids=None,
    position_ids=None,
    name_batch=None,
    inputs_embeds=None,
    embedding_manager=None,
    only_embedding=True,
    random_embeddings=None,
    timesteps=None,
):
    """Replacement forward for CLIPTextEmbeddings: optionally routes the token
    embeddings through an embedding manager before adding position embeddings.

    Returns the raw token embeddings when only_embedding is True, otherwise a
    tuple of (embeddings, other_return_dict).
    """
    seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

    if inputs_embeds is None:
        inputs_embeds = self.token_embedding(input_ids)
        # Early exit: return the raw token embeddings, without position embeddings.
        if only_embedding:
            return inputs_embeds

    # The embedding manager may rewrite placeholder-token embeddings and return
    # auxiliary outputs; default to None so the final return cannot raise a
    # NameError when no manager is given.
    other_return_dict = None
    if embedding_manager is not None:
        inputs_embeds, other_return_dict = embedding_manager(
            input_ids, inputs_embeds, name_batch, random_embeddings, timesteps
        )

    if position_ids is None:
        position_ids = self.position_ids[:, :seq_length]

    position_embeddings = self.position_embedding(position_ids)
    embeddings = inputs_embeds + position_embeddings

    return embeddings, other_return_dict
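

# Usage sketch (an assumption, not part of the original file): embedding_forward
# is written as an unbound method, so it is presumably monkey-patched onto the
# text encoder's embeddings submodule. A minimal way to do that, assuming a
# HuggingFace CLIPTextModel:
def _patch_embedding_forward(text_encoder):
    # Hypothetical helper: bind embedding_forward as the forward of the
    # CLIPTextEmbeddings submodule so callers can pass the extra kwargs above.
    import types
    text_encoder.text_model.embeddings.forward = types.MethodType(
        embedding_forward, text_encoder.text_model.embeddings
    )
    return text_encoder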


@torch.no_grad()
def _get_celeb_embeddings_basis(tokenizer, text_encoder, good_names_txt):
    
    device = text_encoder.device
    max_length = 77  # CLIP's maximum text sequence length
    
    with open(good_names_txt, "r") as f:
        celeb_names = f.read().splitlines()

    # Get tokens and raw embeddings for every name.
    all_embeddings = []
    for name in celeb_names:
        batch_encoding = tokenizer(name, truncation=True, max_length=max_length, return_tensors="pt")
        # Keep token positions 1:3 (the first two name tokens), skipping the BOS token at index 0.
        tokens = batch_encoding["input_ids"].to(device)[:, 1:3]
        # only_embedding=True returns the raw token embeddings (see embedding_forward above).
        embeddings = text_encoder.text_model.embeddings(input_ids=tokens, only_embedding=True)
        all_embeddings.append(embeddings)

    all_embeddings: torch.Tensor = torch.cat(all_embeddings, dim=0)
    
    
    print('[all_embeddings loaded] shape =', all_embeddings.shape,
          'max:', all_embeddings.max().item(),
          'min:', all_embeddings.min().item())
    
    # Per-token mean and std across all celeb names: the embedding basis.
    name_emb_mean = all_embeddings.mean(0)
    name_emb_std = all_embeddings.std(0)
    
    print('[name_emb_mean loaded] shape =', name_emb_mean.shape,
          'max:', name_emb_mean.max().item(),
          'min:', name_emb_mean.min().item())
    
    return name_emb_mean, name_emb_std
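

# Illustrative addition (not in the original file): the returned (mean, std)
# pair can be read as a per-token Gaussian over name embeddings; this
# hypothetical helper draws one random name embedding from it.
def _sample_name_embedding(name_emb_mean, name_emb_std):
    # mean + std * eps with eps ~ N(0, I); one draw per token position/channel.
    return name_emb_mean + name_emb_std * torch.randn_like(name_emb_std)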

